{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random forest classification\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SparkContext and SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark import SparkContext\n",
    "sc = SparkContext(master = 'local')\n",
    "\n",
    "from pyspark.sql import SparkSession\n",
    "spark = SparkSession.builder \\\n",
    "          .appName(\"Python Spark SQL basic example\") \\\n",
    "          .config(\"spark.some.config.option\", \"some-value\") \\\n",
    "          .getOrCreate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random forest tree with pyspark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+---------+---------+---+\n",
      "|age|education|wantsMore|  y|\n",
      "+---+---------+---------+---+\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "+---+---------+---------+---+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "cuse = spark.read.csv('data/cuse_binary.csv', header=True, inferSchema=True)\n",
    "cuse.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process categorical columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Categorical columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['age', 'education', 'wantsMore']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler\n",
    "from pyspark.ml import Pipeline\n",
    "\n",
    "categorical_columns = cuse.columns[:-1]\n",
    "categorical_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build StringIndexe stages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "stringindexer_stages = [StringIndexer(inputCol=c, outputCol='stringindexed_' + c) for c in categorical_columns]\n",
    "# encode label column and add it to stringindexer stages\n",
    "stringindexer_stages += [StringIndexer(inputCol='y', outputCol='label')]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build OneHotEncoder stages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "onehotencoder_stages = [OneHotEncoder(inputCol='stringindexed_' + c, outputCol='onehot_'+c) for c in categorical_columns]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build VectorAssembler stage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "feature_columns = ['onehot_' + c for c in categorical_columns]\n",
    "vectorassembler_stage = VectorAssembler(inputCols=feature_columns, outputCol='features')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build pipeline model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_stages = stringindexer_stages + onehotencoder_stages + [vectorassembler_stage]\n",
    "pipeline = Pipeline(stages=all_stages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit pipeline model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_model = pipeline.fit(cuse)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transform data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------+----------------+----------------+-------------------+-----+\n",
      "|   onehot_age|onehot_education|onehot_wantsMore|           features|label|\n",
      "+-------------+----------------+----------------+-------------------+-----+\n",
      "|(3,[2],[1.0])|       (1,[],[])|   (1,[0],[1.0])|(5,[2,4],[1.0,1.0])|  0.0|\n",
      "|(3,[2],[1.0])|       (1,[],[])|   (1,[0],[1.0])|(5,[2,4],[1.0,1.0])|  0.0|\n",
      "|(3,[2],[1.0])|       (1,[],[])|   (1,[0],[1.0])|(5,[2,4],[1.0,1.0])|  0.0|\n",
      "|(3,[2],[1.0])|       (1,[],[])|   (1,[0],[1.0])|(5,[2,4],[1.0,1.0])|  0.0|\n",
      "|(3,[2],[1.0])|       (1,[],[])|   (1,[0],[1.0])|(5,[2,4],[1.0,1.0])|  0.0|\n",
      "+-------------+----------------+----------------+-------------------+-----+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "final_columns = feature_columns + ['features', 'label']\n",
    "cuse_df = pipeline_model.transform(cuse).select(final_columns)\n",
    "cuse_df.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split data into training and test datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train, test = cuse_df.randomSplit([0.8, 0.2], seed=1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build cross-validation model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "\n",
    "random_forest = RandomForestClassifier(featuresCol='features', labelCol='label')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder\n",
    "\n",
    "param_grid = ParamGridBuilder().\\\n",
    "    addGrid(random_forest.maxDepth, [2, 3, 4]).\\\n",
    "    addGrid(random_forest.minInfoGain, [0.0, 0.1, 0.2, 0.3]).\\\n",
    "    build()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "evaluator = BinaryClassificationEvaluator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build cross-validation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import CrossValidator\n",
    "\n",
    "crossvalidation = CrossValidator(estimator=random_forest, estimatorParamMaps=param_grid, evaluator=evaluator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit cross-validation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "crossvalidation_mod = crossvalidation.fit(cuse_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prediction on training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "|onehot_age|onehot_education|onehot_wantsMore| features|label|       rawPrediction|         probability|prediction|\n",
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_train = crossvalidation_mod.transform(train)\n",
    "pred_train.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prediction on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "|onehot_age|onehot_education|onehot_wantsMore| features|label|       rawPrediction|         probability|prediction|\n",
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "| (3,[],[])|       (1,[],[])|       (1,[],[])|(5,[],[])|  0.0|[9.61727693784312...|[0.48086384689215...|       1.0|\n",
      "+----------+----------------+----------------+---------+-----+--------------------+--------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_test = crossvalidation_mod.transform(test)\n",
    "pred_test.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction performance\n",
    "\n",
    "We calculate the **Area under the Receiver Operating Characteristic curve**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on training data (areaUnderROC):  0.681918715706039 \n",
      "Accuracy on training data (areaUnderROC):  0.6755505721350122\n"
     ]
    }
   ],
   "source": [
    "print('Accuracy on training data (areaUnderROC): ', evaluator.setMetricName('areaUnderROC').evaluate(pred_train), \"\\n\"\n",
    "     'Accuracy on training data (areaUnderROC): ', evaluator.setMetricName('areaUnderROC').evaluate(pred_test))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Confusion matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Confusion matrix from training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {Row(label=0.0, prediction=0.0): 746,\n",
       "             Row(label=0.0, prediction=1.0): 167,\n",
       "             Row(label=1.0, prediction=0.0): 220,\n",
       "             Row(label=1.0, prediction=1.0): 194})"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_pred_train = pred_train.select('label', 'prediction')\n",
    "label_pred_train.rdd.zipWithIndex().countByKey()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Confusion matrix from test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {Row(label=0.0, prediction=0.0): 151,\n",
       "             Row(label=0.0, prediction=1.0): 36,\n",
       "             Row(label=1.0, prediction=0.0): 50,\n",
       "             Row(label=1.0, prediction=1.0): 43})"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_pred_test = pred_test.select('label', 'prediction')\n",
    "label_pred_test.rdd.zipWithIndex().countByKey()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Best model and paramters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max depth:  4 \n",
      " min information gain:  0.0\n"
     ]
    }
   ],
   "source": [
    "print('max depth: ', crossvalidation_mod.bestModel._java_obj.getMaxDepth(), \"\\n\",\n",
    "     'min information gain: ', crossvalidation_mod.bestModel._java_obj.getMinInfoGain())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
